{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "\n",
    "from torch import Tensor\n",
    "from torch.distributions import Bernoulli, Distribution, Independent, Normal\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms.functional import to_pil_image, to_tensor\n",
    "from tqdm import tqdm\n",
    "\n",
    "import zuko\n",
    "\n",
    "_ = torch.random.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST training split. `to_tensor` converts each 28x28 grayscale image\n",
    "# to a float tensor of shape (1, 28, 28) with values in [0, 1].\n",
    "trainset = MNIST(\n",
    "    root=\"~/data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=to_tensor,\n",
    ")\n",
    "trainloader = data.DataLoader(dataset=trainset, batch_size=256, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcAcABAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAACzBVBJJwAO9dnp/wm8damu6Dw5dRjGf9IKw/+hkVPffCnWNJa7XVNV0Kxa1hErrNe/M2cnYqgElsAHpjkc1wlFen+F/gpq+v+HRq95fwaSs5C2cVyhzMSQFzz8oJPHUn05Feb3lnc6feS2l5byW9zC22SKVSrKfQg0RWdxPbXFzFEzQ24UyuOibjgZ+pqCtGw0W51HTdSv4WjEOnxpJNuJydzBQBgHnnPOOhrOorZ8LeGNR8X69Do+mCL7RKC26VtqIoGSScE4+gJqrrekz6Drd7pN08bz2czQyNESVLKcHBIBx+FUKtWGmahqs/kadY3N5NjPl28TSN+Sgmi80y/wBOYre2NzbMHKETxMhDDqOR1HpVWlALEAAkngAV2mm/Cbxpqlgb2LRpIofLMiC4dY3kwM4VCdxJ7cVxZBUkEEEcEGkrrvBngY+MNO166XU0tW0q1NyIjEXM3DHGcjaPl689elcjRWloeg6n4k1NNO0i0e6unG7YmBgDqSTwBz1r0e1/Z58ZXFrHNJc6RbO4yYZp3Lr9dqEfkTXnGv6DqHhrW7nSNTiEd1bthgDkEdQwPcEYIrNrV0Pw7qHiJr9bBEb7DZyXs5dsYjQc49TyAB7+nNZVFFFdH4q8IXHhKHSkv7lDqF7bm4ltFHNupOEBbPJIyegxjvXOV6J4A+E9z490i51C31uztRDJ5XksjO+7gjcOMA54IzXK+K/DN/4Q8RXOjagF82E5V1+7Ih+6w9iKxaKKunR9TXTTqLaddixDbTcmBvKB6Y3Yxn8a6Dwj8N/Efja1nutIt4TbQSeW8s0oQBsZx6njH5it6X4GeLEdkiuNHuH52LFejL/QED9cVwesaLqPh/U5dO1W1e1u4sb4nIJGeR04NUKUAsQACSeABXpeq+CtE8D+Cku/EfmXfiLVLcmzsUcxraZxiRwCGJHPBG3III4zXmdFbOk+Gb/WdF1jVbVofI0mNJbhXch2VmwNoxzjBJzisaiir2n6LqurNt03TL29OSMW0DScgZP3Qe1VJoZbaeSCeN4po2KPG6lWVgcEEHoQe1MBwcivXvC3jnXPFvhnU/Cd14jltdYuGjl0+7mn8tX24BhLgZXIGRzycjvz594ws/EFl4luYvE3ntqg2rJJMdxkCgKrBv4hhRz3rCr0T4e6Np+m6VeePdfhebT9KlVLS2HH2q5P3Rn0UlSf64IpviP4x+KddmvkiuzaWVzJG6QIATEEOVVWx6jJPU/TirPx30+3sPidcvATm6t47iUE5w5yp/RQfxqDw1pFxefB/wASvYWV1eXlzqNrCY7eMuVRAX3YHPUkflWJB8N/GE8ImbQrm3hPPmXZW3UcZ6yFRXS67psfhD4N2tgbu1ubzXdSad5bR1kTyoBt2bx97DMDxxkn8fMaACSABknoBXtMa/8ACoPhs7TxmPxbr6MIwUw9pDgA89iOuPUjj5a8XJLEkkknuau6Lpj61rlhpccixveXEcCu3RSzBcn8692l8KeJbQHwd4FsLnSNNSTbqGu3mYZLuQDB2n7xTsN
g5+mSeM+Mni5NR1G18K2UkstjoeYXnmkLvPMBtZix5OMEfXPtXl1dP4A1bUdF8W295pOjJq9+EdYbdo2chiPvALzkfyz9a9w0HTfG8/jeHxn40a3isdNtJWW2tZN/kEpyvlqSd2Dz1J49q+bbmQTXU0iggO7MM9eTUVex/s8TfaPEut6NMiPZXmms8yEcttdVxn0xI36V5PqkSQateQxLtjjndVX0AYgCqldR4C0PxLrviHyfC1y9rfRxF2nW4MOxOh+Yc45HAzXufg7wDrfguW48WeK/EN5ePaRSyvZ20kkwYYOSST8x74x1714F4z8Qt4q8YanrJ3iO4mJhV+qxjhAeTztA/GsGvTPgmvn+JNcsfKMgvNEuYSF68len8q8zorrPAfgLU/HOsLBbRmOwhZTeXbHCxJ3wT1YjOB/TmvQb7w0t18StW8Ua9pb6L4b0Pyn2Sw4FyIwqxouPlJbAyB6gd68o8Ta7P4m8R3+sXC7HupS4j3lhGvZQT2AwP8Kya9u+Alvp9pNdatP4qgsmDiObTJNiCVQCUZmfrySRt6Y681z3x7h8v4pXL+Wiia2hcMrEl/l25Poflxx2APevMqs6dYzapqlpp9uVE11MkEe9sLuZgoyewya7y9vLb4X6jdaNa6Ra3+uQy/vNT1C33Kg28eTG3Tkn5jndgHFc94g8feKfFEbRatrNzNbtjdbqfLiOMEZRcKeQDyK55ZpUQokjqpOSAxAJr2L4KeGfCNzNFrus61af2lbz/wCjafLMsexgQVc5ILHOcAcVkfFvwp41g1y68ReIIYp7WV9iXFo26KJc4RMYBUc45HJ7k15jXpfwY0Swu/ElzrutRxf2Ro8BnlknXMYc8LnPGRyRn0q9Z+Bx4qmm8WeL9Xn0631S5YWcYjMk9wSRs2LgkjGRjHYY4rgPFnh5/Cvii/0V7qO6NrJtE0YwGBGQSOxwRkZOD3NYtem/BmCO6u/FltOgeGTQLjch6H5krzKivVfAE3gLw54Mu9f8RR2eravJI0VvpsgDsoA4yp4GSM78HAwOuRXafDjxr4h1O51HxFqC2ukeC7GJi8FtbokYcAYVfl3MecnB64HoK8S8ZeIG8U+L9T1ooqC5myiqMYRQFX8doGfeqOlaNqWu3q2el2M95cMQNkKFsZOMn0HueK9S0v4Q6f4d09dZ+Iurx6ZbZ+SyhcNLIfTIz+S5+ornvib4603xcdJsNHspoNO0mEwQy3LBpZVwoGepGAvcnOcmuAr2Xwh4Xfx/8Hl0HTNUtYdRs9Ue6e3lb7yFNoJwMjrwenWsifTvCXw2keWTUYfEfiSLmCGJM2ls/wDec5+cgjpkdsivPdU1S91rU59S1G4a4u7ht0srAAsfoOB9BTbLU7/TWdrC+ubVnADGCVkLfXB5pt3f3moS+be3c9zJ03zSFz+Zru/iMJV8K+Al3Yt/7GBRB0DlvmP1Py/lXnlei/BXw1B4h8fRS3gR7TTYmu5I2B+cjhOnoxDf8Bx3rE+Ini2Xxl4yvNS81ntFYw2aldu2EE7ePU5JPua5WlVirBlJBByCO1eofCDWdR1n4s6Q2r6pcXrJHN5f2yZpTnymwFLE4Pf8K80vJZZr64lmJaV5GZyRglicmoa9r0C/8E6F8O7eOz8Upp9/fRY1WSG0aW8fI5hjJ4RR0z0OM1zd18VLrRRDpvgeNtK0q3l83fKqvPdv3eYnIOfQcdPQY5zxJ4nsfEFtCYvDen6df799zdWpZfOOMcJnag7nA6+nfm69C+CmvLoXxLshIQsN+jWTse27BX/x5Vql8W9Lh0j4oa3b24URvKs4VVwAZEDkfmxriq9S0v4M67qnhnS9b8P6na3M15GzTIk3liAEDCbu7csGHGK6HTNN0/4U+GtZm8Sa6txrep2UlpDp9pL5oQMvBYeue/QDpnNeGUV3vwavJrT4p6QIo/M88yQuvP3SjZP4Yz+FcPdBRdzBEMaiRsIWDbRnpkcH6ioq6zw9f+IfEX9
l+CbDUDbWU85HlowjVyx3FnPG8gDgH+6ABmus8f8AxU1Qyz+GNC1J20a3thYyzSKGkumAw77mG4Z6cHtnvXk9aGh21hea5ZW2qXT2tjLMqzzoASinvzx+J6da9U0T4HJeaiNRl8SaXP4bhKzyTxSne0P3sMMAIdo5JPGT6VxvxO8WW/i/xlNeWC7NNt41trMbNv7te+OwJJIHYECuNrt9C1zwUfCa6Nr+lalDcrO0xvtMEJkm/uhjKMqBkjCkA/UmqHjrxavi3WLaWC3eCysrVLO1SVt8hjToXbuxzk1y9S20P2m7hgMiRCR1TzH+6uTjJx2Fema58DvE9triWmi2jahYMkey/Msao5Kgs2M5VQSR34Ge9dx4x1FfAPwZbwjq2sJqevXSGIIHL+WhYMc55CheBnGTjAx0+dq9E8B654dHhHXPC/iLU5tMhvp4LhLmO3aZTsbJVlXJIOBjj3rqPH3xR8PNef2h4QNzLrhj+zLqUyMFt4MYIhRvuluedoPJ5rxWSR5pXlldnkclmZjksT1JPc02vSvgyv2jVvEtgrFZrvQLqGMj+8dtea0Vs+FfDd74s8RWmj2Ktvmcb5AuREmRuc+wzXZfEfxbaR6RZ+AdAYPo+kMFluQADdTDO5sY4AZm+pOfSvNKvabrWq6LJJJpWp3li8g2u1rO0RYehKkZqG8vrzUJvOvbqe5l6b5pC7evU/U1XooooooqxPf3l1BbwXF3PNDbKUgjkkLLEpOSFB+6M+lV6t6fqmoaTO0+m311ZzMpQyW0zRsVPUZUg4qpRRVnT7+60vULe/spmhureQSRSL1Vh0NWNd1mXX9ZuNUuLe2gmuCGkS2QohbABbBJwTjJ9yelZ1FFFFFFFXrPWdU062mtrLUry2gm/wBbFBOyLJ/vAHB/GqJJJyTk0UVNa3dzY3KXNpcS286cpLC5Rl4xwRyKhJJJJOSepoooooooooooorRg1/WbWEQ2+rX8UQAUJHcuq4HQYBqhJI8shkkdnduSzHJNNooooqW3urizm862nlglwV3xOVOCMEZHqCRUVFbXh/xRqPhmPURpvkpJfW5tnmZMvGpPJQ/wkjI+h9cEYxJYkkkk8kmkr//Z",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAALaUlEQVR4Ae1ZeVhWVRo/poC5EppAiui4QJgFmiktksvkmKWZoeL28DiNmmmZgxpOklu54YjLUGJJT2aK2YambZCamssQueSSioIoghug+Cn3Peeb99773e/ec8+5aE+jzzR5/vjO+/5+73vuved3tns/Qm6X2z1wuwdu98Afsgda5+X/IZ/7/+WhF5fAZ7f0Wdqn0/R2t/SKDhdb6N4X6kD9F+CsbFsjEROPfDhxoq8N/a1u4DYKPzX4ra3w+dUDAgKSZn96z4du1+s8o3qRFwDgvIibSLczYabDWa9RdwwH3IhTN/j5yX6SwGbnKfSQ4CrUus0LbqqWTyQd7hOzzSHNAi9wLbV4aI4sZ2rpyqNVeHX+Ovify65sWIZlxoNOca0zgU7oImWrrS5vIiWqBJu2HJa2Rnv0/LW0/PvHheCHTlK4WAzRlo7p3JePenW9g4DxlxXamY+9rtf8rZ9wwCySxNX+1EnANvNOFADFNCzL6wmpDdnpIAG0AbNd5f15KOCMJuCFJ3jY2Zurxes/sHdyc2lkNACNkzKkViF7Xs5UgUZd0MTDH2Xos892EoSo9egJ7JldzwGdbLaStMK00bpj6d5wDvA6rytSATsu2UvpKwOWdPQGeo3wty8DPbEPzshaTHESMFNVziMgPOJtzDAaMhZp2E71JvjGTo26zE6gHPPtOOeHhk8rKUnXoKMYfDYLS2pWDpq9uDiP0zqP0j4yQsWymaWTxaC/J75P6c82POCIJuD2Da4yG6O7K/Seic+iq0z+KC9gY/a+yVmt7ueUfaE1rYhmDzgDNGsv9vhqO1X/7Yt4vUOhrQAetXOE+H/nJOBYgNMz35g5M0sVUiagO1JsTkc6fxWgGnHnDj8ghOSyvSjEnwTcBLqnXmTYgQc1pMWTLVoE61xdVN62IuvEDFjX2Ey3Wf2cOhLjYl5co6hPV3nAlvTMOy9SmlObtEmzEZrbHido9nha+EAft6W383gBN7IpslzyaKGiDBOYGtHlkN3Fp84GoAl2Ml69xcMhRC7gPccovBZqz1H9GiEhQWpdrwBgrZ9qcQVnYDQHWJxDVBsq+922bUENee5H1I/da4nmzXd2Il2WOlwco4MYc3XggzVv+5WjrSSwBwphVz0DgI8J3lRQUAZ0l9o7kM9z+NDV0uggO+jx1fPLujq9Eu8mhF5qZwTdX8ELuJ11MiiuXkZpFgdoDoq0EbepIXgr2CxfvgA4uqopIU9LZyCZgqvkGD6D92IvAaTwkOqhgI5pP0J3jIgsl25MQXtQoY/EBjWkQRo7t7tfGN6uvfi+fYWxKDuKfh8Ks5pIcA8U4mYjJWz345pyYQ3CupzAzhMj5tHsO0QUkdYrafFPz+kUhZVGzKuMEzCwiIUYlLVuSJWzXa2AZs+ksAj1IwcBxL3gnqkPN1KjnpcLSK4j4EB1CVUbtxX/i2yBDTLcGcp+HEa1V8E2HwMy68FzKAo4zgQ4K4UurMMBhtN1OWPXRorTkvgnU5ioRr2cnGwEW+sQxsZafY/9NT5VxZgO6KXiAG8oRtTOptKTll8mlPZo4BkxFL43MtPZq4ap1ivYIX+r77Gb5VAlScCTqOuzOwmp2buCThNIL/Cug4D4pjDGG2Q3Bu934ZP+G1sXSqaTgCFnXDEYvRQKhBwSfqAS5XPYA2tNy+vdR6IRNvOQgklXn5QMiLrfuSkuVuNfwZOMW7IVygV8ohwgT9/Y8bA2X7xRQlqU5b83pprA4IlXfTq9cAJ2M1BSr3/mVTbY61qMUQr9sr7F10z/M9pHiJY7ATJq20ndfylxcuIu2FJdxjrNwGZTNm/ejCRcGOEZb3y2k4Btj4A6NROugWTp6uvS9GOL+LZ07026Si4fIcl62q4pbe2JTwE93opEfgJQfoD+EGqniVzArwC2aB1+16BSjyVk9i2ldJKwf26n
2Wakm241nHQWi+YDUQkpqWWXStaVQbjBWOpnSpXNgRZfNxsBNG00aVsZBeVpgUSgVod1ODgpLWwhY52W0LZ5qJ32GvG5NI1kSo93NeLxUjsm+wXvrFwuy3sJdzIsH8k4N+stg1Xs4S9KdAkhWdsPvHF1x0LhdG1f+iAqhh6QCOiWLaH9crKDtDYSAfbolrdJr9H2a0pTG3tdzXjqCowzEQpLDCeVns/NzaWssnT7gsFNfIorDcJSN8PDdbrF95j+RVo/F5yEIpEkPg+dhEuFa3DJKJrgK+EdBTyuvgtpX2KelKWhgKUSfAiKfhhgh/xmMKFnXNzQUrmAO2nBnyVN6lDTdj2Xqbsn+447W/QESCKBuC8t8bvvYKm3O81m5DPQ4J92wdXRhiPU/kOBfsOjsXDaOyn9ZtGv63jpSZ+rZXgnDRjBjnoJ03hLUZQw0/VaHc/Sw3MjgjZpS5cX1Q3f3gBTHiEB6pcYGOBnY1UXNVojgUnoPx68D8sCALmAr8gEHKC4irpE4rkHl4OT8hlPSLWp7IgwVTr6koCptOxe2a0Y2OAdqoLakcWAJgEQsk3dl6IBkg3UUqOAMRbXZuIWMcIGce41eu1xDoiF44bvNwPyexiOrc5gc2wIupHHFGWtCBtIZzcVj1s+swDW+5O7d1PXtI8BvuwaJZzE1W0uwmhErOs7CdiPVQgqkOxjw7GJiK3qmvC+2JiO+DF2oAlPBuecHYKvJpQ+zOM2r8YmFDDNCs6in5DIU3Qcwc8x46yEYaOAtmsZDNZv4uiVTQk95P7pGynN5SY8iYWFnvzIlfCxpSnezJC9JJcoytY6fJzV60Hhbquv2tVnQ9nou0iHHXCoC6n3lxVlYA4gI/ZfKGCK4Yh1fycB+7ArrYXwl0NUqPNF6B8RUU9gPQB+30ywcUUVYxCZSb9yTNIT5qOAXO4s7MXIk7C8oGRnY2nnVCWg70YKL/ICmfcVtuQU7iGVG0xEtfrTfB0Yf4E6jlBCpAJSRYnjW+M9iYAvQPnAgJ5rLkOS1rEkbv36VnwSIWNlAvr0ulOPG46bp3wJJQdYqr0t3a+/hP4iYRp8PkhDg0vF14jEClTmMDveTpJHSHBSfx2v/i1jlY9ZY3DhjB5Vikt2scNdooBOS3mtEUBX+FtbM+2g8cfUM8BO+7EqFq4tigyJzcynx1fp252ZY7Ey3MMsnm6m43QX1ywzTDYDi6Ai5xAK9Fp1M06wfsGziu0ZH9sImuIBQ/A76qUuQooGpJTVlBOJUCRbtT5gB2NakvZxPzI2T8hMWFlcXLI+THqf+O2mvnalwDko8x7uou3LtRMcHmI42HRQwLGmZ7XqrgZ4ST7/Arv+rMq3va9Ax2JvnsIPJrB1urUtu53B4u1QZCF1JQvPbYkaLZmBuXglyExoWcMSJ5if4vi1CYhHnsVvYtmNW9m3/YQMHUgpFR5PI0KPKdNkKdHbGMtbX8YY/dnhdVWWhdhq/OMDF4Q7X8dcd3kMH9UrC+/x3XE21Izx3e8k4L34idiMs1gBH2n/Rnz/jGcVslCkyQ/qAQ2KF1pB0c4QP7s/rtCjYqAFaeumwh5Yd+iCxEBfS5DMxJO4REBVeiROL3UcNCnsWVlz5Bd4T4qT5BdwAmE5J6cd0b9hTk52do6aW97NMcyB2M0ypUz4MpC8NhLScW2BOvsuvyEfZcFTUcD5wkZku0SGe6kNIdcVEDuukz3phvzQfYKAUe9q8h3OXdTWuYnTrnApORnkwhLiN2HCSsYutpOmOYPNP1SlU0vl3I7OYQ7MMpYtZVYCjJYRs1G9/bNm+su4G8bixRkYtPk6M5DEQ1bEDV/hOoF+I8/C2pFBVUat3lPVllxl6q8k/QamJWzZkpY2MOpXJqrhzX4YJctqsw5Snd8gZBk3Hav3JayRT/ubfunf4QXmwLH/Mf3wX8jFVb2R/w47+WbecjfJP28383q3274FPfAf6yvFvPLt1QQAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=448x28>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first 16 training digits, concatenated along the width axis.\n",
    "x = torch.cat([trainset[i][0] for i in range(16)], dim=-1)\n",
    "\n",
    "to_pil_image(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ELBO(nn.Module):\n",
    "    \"\"\"Single-sample Monte Carlo estimate of the evidence lower bound (ELBO).\n",
    "\n",
    "    Given x, draws one reparameterized sample z ~ q(z | x) from the encoder and\n",
    "    returns log p(x | z) + log p(z) - log q(z | x).\n",
    "\n",
    "    Arguments:\n",
    "        encoder: conditional distribution q(z | x), called with x.\n",
    "        decoder: conditional distribution p(x | z), called with z.\n",
    "        prior: distribution p(z), called without arguments.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        encoder: zuko.lazy.LazyDistribution,\n",
    "        decoder: zuko.lazy.LazyDistribution,\n",
    "        prior: zuko.lazy.LazyDistribution,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        # Attribute assignment registers the three models as submodules, in this\n",
    "        # order, so `parameters()` covers the encoder, decoder and prior.\n",
    "        self.encoder, self.decoder, self.prior = encoder, decoder, prior\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        posterior = self.encoder(x)\n",
    "        # rsample (not sample) keeps z differentiable w.r.t. the encoder.\n",
    "        z = posterior.rsample()\n",
    "\n",
    "        log_likelihood = self.decoder(z).log_prob(x)\n",
    "        log_prior = self.prior().log_prob(z)\n",
    "        log_posterior = posterior.log_prob(z)\n",
    "\n",
    "        return log_likelihood + log_prior - log_posterior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mlp(in_features: int, out_features: int, hidden: int) -> nn.Sequential:\n",
    "    \"\"\"Two-hidden-layer ReLU MLP used as the hyper-network of both models below.\"\"\"\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(in_features, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(hidden, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(hidden, out_features),\n",
    "    )\n",
    "\n",
    "\n",
    "class GaussianModel(zuko.lazy.LazyDistribution):\n",
    "    \"\"\"Conditional diagonal Gaussian distribution over `features` variables.\n",
    "\n",
    "    An MLP maps the context `c` to a mean and a log standard deviation per\n",
    "    feature. The features are wrapped in `Independent(..., 1)`, so `log_prob`\n",
    "    sums over the last dimension.\n",
    "\n",
    "    Arguments:\n",
    "        features: number of Gaussian variables (event size).\n",
    "        context: size of the conditioning vector `c`.\n",
    "        hidden: width of the two hidden layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, features: int, context: int, hidden: int = 1024):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hyper = _mlp(context, 2 * features, hidden)\n",
    "\n",
    "    def forward(self, c: Tensor) -> Distribution:\n",
    "        phi = self.hyper(c)\n",
    "        mu, log_sigma = phi.chunk(2, dim=-1)\n",
    "\n",
    "        # The network predicts log(sigma) so that sigma = exp(.) is positive.\n",
    "        return Independent(Normal(mu, log_sigma.exp()), 1)\n",
    "\n",
    "\n",
    "class BernoulliModel(zuko.lazy.LazyDistribution):\n",
    "    \"\"\"Conditional factorized Bernoulli distribution over `features` binary variables.\n",
    "\n",
    "    An MLP maps the context `c` to one logit per feature. The features are\n",
    "    wrapped in `Independent(..., 1)`, so `log_prob` sums over the last dimension.\n",
    "\n",
    "    Arguments:\n",
    "        features: number of Bernoulli variables (event size).\n",
    "        context: size of the conditioning vector `c`.\n",
    "        hidden: width of the two hidden layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, features: int, context: int, hidden: int = 1024):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hyper = _mlp(context, features, hidden)\n",
    "\n",
    "    def forward(self, c: Tensor) -> Distribution:\n",
    "        # Parameterize by logits rather than sigmoid(logits). When built from\n",
    "        # `probs`, Bernoulli.log_prob first clamps them to [eps, 1 - eps], which\n",
    "        # loses precision and zeroes gradients once the sigmoid saturates. With\n",
    "        # `logits` it uses binary_cross_entropy_with_logits, which stays stable.\n",
    "        logits = self.hyper(c)\n",
    "\n",
    "        return Independent(Bernoulli(logits=logits), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BernoulliModel(\n",
      "  (hyper): Sequential(\n",
      "    (0): Linear(in_features=16, out_features=256, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=256, out_features=256, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=256, out_features=784, bias=True)\n",
      "  )\n",
      ")\n",
      "271,632\n",
      "---\n",
      "GaussianModel(\n",
      "  (hyper): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=256, out_features=256, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=256, out_features=32, bias=True)\n",
      "  )\n",
      ")\n",
      "274,976\n",
      "---\n",
      "MAF(\n",
      "  (transform): LazyComposedTransform(\n",
      "    (0): MaskedAutoregressiveTransform(\n",
      "      (base): MonotonicAffineTransform()\n",
      "      (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]\n",
      "      (hyper): MaskedMLP(\n",
      "        (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (1): MaskedAutoregressiveTransform(\n",
      "      (base): MonotonicAffineTransform()\n",
      "      (order): [15, 14, 13, 12, 11, ..., 4, 3, 2, 1, 0]\n",
      "      (hyper): MaskedMLP(\n",
      "        (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (2): MaskedAutoregressiveTransform(\n",
      "      (base): MonotonicAffineTransform()\n",
      "      (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]\n",
      "      (hyper): MaskedMLP(\n",
      "        (0): MaskedLinear(in_features=16, out_features=256, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): MaskedLinear(in_features=256, out_features=256, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): MaskedLinear(in_features=256, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (base): UnconditionalDistribution(DiagNormal(loc: torch.Size([16]), scale: torch.Size([16])))\n",
      ")\n",
      "235,104\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "n_features = 16  # dimension of the latent variable z\n",
    "n_pixels = 28 * 28  # size of a flattened MNIST image\n",
    "\n",
    "encoder = GaussianModel(n_features, n_pixels, hidden=256)  # q(z | x)\n",
    "decoder = BernoulliModel(n_pixels, n_features, hidden=256)  # p(x | z)\n",
    "\n",
    "# Learned prior p(z): a masked autoregressive flow with 3 transforms.\n",
    "prior = zuko.flows.MAF(\n",
    "    features=n_features,\n",
    "    transforms=3,\n",
    "    hidden_features=(256, 256),\n",
    ")\n",
    "\n",
    "# Print each model and its parameter count (with a thousands separator)\n",
    "# instead of repeating the same three statements per model.\n",
    "for model in (decoder, encoder, prior):\n",
    "    print(model)\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    print(f\"{n_params:,}\")\n",
    "    print(\"---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 64/64 [06:17<00:00,  5.90s/it, loss=72.7]\n"
     ]
    }
   ],
   "source": [
    "# Fall back to the CPU when no GPU is available instead of failing on `.cuda()`.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "elbo = ELBO(encoder, decoder, prior).to(device)\n",
    "optimizer = torch.optim.Adam(elbo.parameters(), lr=1e-3)\n",
    "\n",
    "for epoch in (bar := tqdm(range(64))):\n",
    "    losses = []\n",
    "\n",
    "    for x, _ in trainloader:\n",
    "        # Binarize the pixels (the decoder is a Bernoulli model) and flatten\n",
    "        # each (1, 28, 28) image into a 784-dimensional vector.\n",
    "        x = x.round().flatten(-3).to(device)\n",
    "        # Maximizing the ELBO is minimizing its negative batch mean.\n",
    "        loss = -elbo(x).mean()\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        losses.append(loss.detach())\n",
    "\n",
    "    # Report the epoch's mean training loss (negative ELBO) on the progress bar.\n",
    "    losses = torch.stack(losses)\n",
    "\n",
    "    bar.set_postfix(loss=losses.mean().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.0948e+00, -2.2203e-01, -9.5326e+00, -3.6423e+00,  3.6458e+00,\n",
      "          2.2593e-01,  1.5169e+00,  7.4738e+00, -4.9242e+00,  4.2030e-01,\n",
      "         -2.9945e+00, -2.3906e+00, -2.0135e+00, -2.6839e+00, -2.7722e+00,\n",
      "         -4.5534e+00],\n",
      "        [-2.8519e+00, -2.9411e+00, -1.4257e+00, -1.2589e+01, -6.7205e-02,\n",
      "          4.7825e+00,  4.3651e+00, -5.5354e+00, -2.6279e+00,  1.7060e+00,\n",
      "          4.9544e+00,  1.1541e+00,  6.0123e-01, -4.0677e+00,  2.2275e+00,\n",
      "          2.1673e+00],\n",
      "        [ 4.4487e+00,  6.9984e+00, -5.9975e+00,  1.1340e+01, -2.5034e+00,\n",
      "         -5.0558e+00, -1.2098e+00,  4.1107e+00, -6.5148e+00,  3.1279e+00,\n",
      "          4.5749e+00,  1.0275e+01, -9.4106e+00,  4.3427e-01,  1.7028e+00,\n",
      "         -3.5465e+00],\n",
      "        [-1.5128e+00, -3.7530e+00, -1.0134e+01, -1.8204e+00,  2.7784e+00,\n",
      "          3.7442e+00,  6.1496e+00,  5.4763e+00, -9.7073e-01,  4.0887e+00,\n",
      "         -2.8771e+00, -4.4852e+00,  2.4449e+00,  6.8069e-01, -1.9139e+00,\n",
      "          1.1641e-02],\n",
      "        [-2.9815e+00, -3.1523e+00,  1.8994e+00, -9.2917e+00,  5.5131e+00,\n",
      "          3.2415e+00,  6.4441e+00, -4.2406e+00, -1.9622e+00,  4.5102e+00,\n",
      "          1.4325e+00, -4.4551e+00,  4.1225e+00,  3.9152e-01,  2.9209e+00,\n",
      "          3.8125e+00],\n",
      "        [ 8.5755e+00, -4.1220e-03, -9.9238e-01, -1.0660e+01, -6.5195e+00,\n",
      "         -1.2357e+00, -6.7705e-02, -1.3971e+00,  2.8105e+00,  1.5272e+00,\n",
      "          8.3772e-01, -3.4890e-01, -3.9543e+00, -3.0899e-01, -2.0501e+00,\n",
      "         -3.0576e+00],\n",
      "        [-2.9376e+00, -9.4083e-01, -5.6199e+00, -2.3707e+00, -2.8373e+00,\n",
      "          2.8221e+00,  1.1896e+01, -2.4755e-02,  3.9772e-01,  7.1074e+00,\n",
      "         -2.9341e+00, -1.3623e+00,  6.0669e-01,  8.8308e-01, -2.3005e+00,\n",
      "          1.1948e+00],\n",
      "        [-4.1739e+00,  1.6660e+00, -5.0737e+00, -7.7492e+00, -3.0098e-01,\n",
      "         -2.1164e-01, -7.5466e-01, -1.5586e+00, -4.8754e+00,  2.9488e+00,\n",
      "          2.0603e+00,  5.6619e-01, -3.3756e+00, -2.2467e+00,  6.6043e+00,\n",
      "         -3.2417e+00],\n",
      "        [-6.6744e+00, -2.8065e+00,  1.1901e+00, -1.9729e+00,  4.6420e-01,\n",
      "         -2.3408e+00,  1.0276e+01, -2.8693e+00,  2.9224e+00, -7.3012e+00,\n",
      "         -6.0294e+00,  6.5244e+00, -4.1638e+00, -5.5633e+00,  5.2043e+00,\n",
      "          3.5723e+00],\n",
      "        [-6.5320e+00, -6.5026e-01, -2.7060e+00, -1.9325e+00, -1.9741e+00,\n",
      "          9.5200e+00,  1.4755e+00, -3.4872e+00,  4.9008e-01,  2.5770e+00,\n",
      "         -2.0502e-01,  1.5714e+00, -7.2517e-01,  3.6877e+00, -6.1762e-01,\n",
      "         -1.9826e+00],\n",
      "        [ 7.5951e-01,  3.8501e+00, -1.0923e+00,  4.5650e+00, -5.6751e+00,\n",
      "         -1.3243e+00,  3.1173e+00, -3.2989e+00, -2.2346e+00,  9.8362e-01,\n",
      "          2.8020e+00, -3.8469e+00, -5.2273e+00,  2.0790e+00, -1.3828e+00,\n",
      "         -6.6659e+00],\n",
      "        [-4.5603e+00,  2.7711e+00, -3.7077e+00,  2.5227e+00,  8.9348e-01,\n",
      "         -5.1554e+00, -5.1303e+00,  3.1563e+00, -3.9091e+00,  2.9484e+00,\n",
      "          1.2177e+00,  5.6518e+00, -5.4456e+00,  3.7657e+00,  6.8424e+00,\n",
      "         -2.8479e+00],\n",
      "        [-1.0015e-01, -1.1281e+00, -1.1024e+01,  4.2252e-01, -9.9055e-01,\n",
      "         -3.2478e-01,  7.8611e+00,  4.4904e+00, -4.0893e+00,  4.0988e+00,\n",
      "         -4.8009e+00,  2.9625e+00, -1.2395e+00,  7.1263e+00, -3.9094e+00,\n",
      "         -4.6008e+00],\n",
      "        [ 1.8952e+00,  6.7809e+00, -8.1615e-01, -2.2078e+00, -6.0371e+00,\n",
      "         -7.3572e+00,  1.1905e+00,  1.1581e+00,  2.1859e-01, -5.0795e+00,\n",
      "         -1.8859e+00, -2.3903e+00, -2.7011e+00, -4.0724e+00,  2.6256e+00,\n",
      "         -3.9604e+00],\n",
      "        [-1.8115e-02,  7.5609e+00, -1.0310e+01, -9.5932e-01, -1.7009e+00,\n",
      "         -3.3912e+00, -1.2947e+00,  3.1317e+00, -4.3950e+00,  2.2856e+00,\n",
      "          2.8459e+00,  9.6810e+00, -3.8574e+00,  5.3783e+00,  3.2217e+00,\n",
      "         -9.5469e-01],\n",
      "        [-1.7537e+00,  2.3684e+00, -9.3794e+00, -1.8861e+00, -1.0208e+00,\n",
      "          2.5571e+00, -6.9086e-01,  4.8350e+00, -4.8028e+00,  3.5592e+00,\n",
      "         -1.0633e+00, -2.1759e+00, -5.3362e+00, -1.7342e+00,  3.2121e+00,\n",
      "         -6.3396e+00],\n",
      "        [-3.5549e+00, -3.2155e+00,  1.1701e+00, -5.4587e+00,  3.9894e+00,\n",
      "          3.3543e+00,  6.9619e+00, -2.4630e+00, -5.7152e-01,  3.5754e+00,\n",
      "         -2.4762e+00, -3.8352e+00,  5.4903e+00,  7.5193e+00,  1.4957e+00,\n",
      "          4.8524e+00]], device='cuda:0')\n",
      "torch.Size([17, 16])\n",
      "Independent(Bernoulli(probs: torch.Size([17, 784])), 1)\n",
      "tensor(1., device='cuda:0')\n",
      "tensor(0., device='cuda:0')\n",
      "torch.Size([1, 17, 784])\n",
      "tensor(0.7059, device='cuda:0')\n",
      "tensor(0., device='cuda:0')\n",
      "torch.Size([28, 28])\n",
      "torch.Size([17, 784])\n",
      "tensor(1.0000, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0., device='cuda:0', grad_fn=<MinBackward1>)\n",
      "torch.Size([28, 476])\n"
     ]
    },
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcAdwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+r+maHq2tNIulaXe3xjwXFrbvLtz0ztBxSanomq6LJHHqumXli8gJRbqBoiw9RuAzVGiu41D4aX+mfDS18ZXV9boly6+XZkHeY2+6wbpk9dvpznPFcPRW14a8Ka14u1IWOjWT3Eg5d/upGPVmPAH8+2a6rxV8GPFPhPRm1W4NneW8fM32N3dol/vEMo+X1I6fTmvO6fHFJK22ONnPXCjNaFh4c1zVAx0/RdRuwv3jb2ryY+uAao3NtcWdzJbXUEkE8bbXilQqyn0IPINRUqI0jqiKWdjhVAySfSuk8c+Ev+EK8QLoz6jFe3KW8clwY0KiKRhkpyecDBzxkEcCuapVUswVQSxOAB3rX1rwvq2gauNKvrcfbvKWVoIXEjICM4YLnBxzj05rHorR0XQtU8RaithpNnJd3LAtsTsB1JJ4A+tLrfh/VvDl8bPV7Ca0n6gSLww9VPRh7is2pLe3mu7iK3t4nmnlYJHHGpZnYnAAA6kmu51z4QeKfDnhWTX9SS1jhi2mWBZt0sYLBRnA29SOhNcFWl4f0221fXrOwvNRg062mkCyXU5wka9z/hnAz1IHNerX/wAJPCes2UsPgTxUupazAhla1nuI2EqDqF2qMHOOTkc84614wylGKsCGBwQRyDSUUoBPQVpN4d1qOwe/k0q8js0UO08kLKmDjB3EY5yPzrMq9pGkX+u6lFp+m27T3Mv3VXjAHUkngAdya1tQ0/QdFt3t/tq6rqiNhjbk/ZRg4IDcFvqOOO+a5utfw34Z1TxXqq6bpMAlnKl2LMFVFHViT0HNdJqvwwn0hGW48U+GftKcNbC/+cfgVrj9S0270i/lsr2IxTxn5lyD9CCOoNVKu6Zo+pa1cm30vT7q9mA3FLeJpCB6kAcD3re1T4Z+M9G04399oF1HbAZZ1KuUHqwUkqPciuUrofBXhW48ZeJ4NJhfykKtJPMRxFGo5Y/oPqRWVq0Vnb6veQ6fKZrOOZkhlP8AGgOA349ap0V1+k/C7xprdot3Y6DO0D8q8jpFuHqN5GRWnffBTxxp+ltfy6dCyopeSKO5RnjUdyM4P4E156Rg4qW3tp7y4jt7aGSaeQ7UjjUszH0AHJroF+HvjJhkeF9XxjPNo4/pWHfafe6ZctbX9pPazr1jnjKMPwNVqK29U8MXmkeGtG1q6dFj1YzGCLB3BIyo3H6luPbB71iV2XgXxfovhSS4l1Twvba1K3+padgPLPGeCrDt1ro9U+P3i67VYdKisNJhVv3Yt4BIwXspL5X8lFdX8X7nU/8AhTPhqPxHLE+uz3aSygKEYfu5CflGOQGQHHGTXz/RXv8A8fnbRPBnhPw1b4NqufmI+Y+RGiL+khz9BXgFb3gzwvc+MfFVlotuxQTNmWXGfKjHLN+XT1JA717h4g+JOj/DtJvDXheG0VLCMR42MzSS5+bJHGR3J6nPpisX4c/E3WPFvxD/ALI1iOGbT9XgkhktwvyLtjZs4OeoBBHfPtXjWt6euk6/qOmrJ5i2l1LAH/vBGK5/Sus+HfxJk+H0OqeRpUF
3PeKmySRyvllc9cdRz0GPrXrXw68b+MvFU194m1u6tbLwzp0Ll4oYAqzOFzgM2WwByTu64Hrj581/VrrXtfvtWvP9fdzNKwHRcngD2AwB7Cs2u4+EOhvrvxN0eMLJ5VpL9tldBnYIvmXPsX2L/wACrH8c64PEfjnWdWWXzYp7p/JfbtzEvyx8f7irXP16D8INCXU/F0mrT2s9za6HA2oNDBGWeaRf9XGuD94tyB32kd66rw/oPiLw3c+IPiR4stTayizmnt4ZSN0s83yqGQElVy2MHBHHpXilFfT3w88H3um/DvSIdHuYrW61hlu9RvduZFhIyqJ74wOehLEdaxfj3e6Pf6ZZ2aaxaNdWEjboUIklLEYwcfdHHPvj0r58OM8HIrqPhxqdtpHxB0e+vLlbW3jlIedxkJuUrk+3PXtXrfjnTLe7+H2vSeFfFJ12TzkutWBuVmbyhz8u3hVBGfoDzxXz3RXqXwh8HeJ/+E90jVUsLmzs4GMslxPEUV4ypBC5xncDjj1zXH+P1VfiJ4jCBgDqVwcEY6yE1ztWdPtlvdStbV5lhSaVY2lbogJAJP0r0rxn4i13wTqUGiaKkWnaZBCPskyRI73C95SxB5JycDGK4K/8Ua/qkUsN/rWoXMUpy8cty7I3OR8pOOoH5Vn2lrNfXkFpbpvnnkWKNc43MxwB+Zr0zX7J9P3/AA88H2jzzRAPrWohcNcSDllLfwQpnocDI9cluQi8N281/cW9nffbls7N7m7mhj/doVGSFJPzL0G7jJ6CudOMnHSrumahNYSyCK7ntknTypXhPzbCRkdRnp0rttE0z4c3eoWdot/rk97LIiKZ4UihZyeB8pLDn3rlvGFxf3Pi7VDqePtcdw8MijopQ7cD2GOKTwjaaVfeLdLttcuFt9MkuFFxIxIG30JHQHpntnPavZ/E/iTxjaX8Xh34b6IbbQlCxW91YWokExP3m8wgqoyfvHngktzWhqfiib4ZeB7nTfEuuya54p1BWKW2/wAxbcEYBYn+EdefvHgDAJr5wdzJIztjLHJwMV2ng+SS08D+N721nMd0tpb24C/eMUkwEn4YAB+tXvC+neDdD0ODUvG9ne3cmokmztINyFYgSPMY7l4JzjntXIeJtKj0TxPqelxOzx2tw8SM3UqDxn3xWUCQQQcEV9BeC9U8RfEzwc+myapeaKNLRV/tOHISfA6McjkLgnnvmqHiiXWPhLpqvp+pzaq2rjEl/c5deAegJIzg9yehrw12LuznGWOTivYPg3r/AIO8JaXqet6xfLHrJfyIYzGzFYiAcqB1yc5/3R6891b6tpvxF1c6foPj/XEvIojNsigMMRAxk/dB6leCa+dtc1a+1jWLq81C9a7uJHw0zfxAcDHtgCs2uj8CeFZvGfi+x0ePesLt5lzKoP7qFeXbODg44BPG5lHerfxH8UweKfFTyWEaRaRYxiy06NAQogTIUgEDGeTjGQCB2rkaK9R+CPgn/hJfFY1a6wNN0lllcEf62Xkov0GNx+gHfNZHxX8bP408ZTSQyE6bZkwWa54Kg/M/1Y8/QKO1cLQOvNfRf7Q9pbaj4Q0DXILpJI45zHEVORIsqbtwP/bMfnXzpXq/7PlxDD8Q7mN5ESa402WO3DfxPuRsD/gKsfoDWVJ8GPiFNqEqPopdi5LTtdRbWyfvZLZPr61sxW2lfB2C7nk1S21TxlJE9vDBaHdFYbsguzEctjHGAecYxzXkrMzuzuxZmOSSckmtrwj4ZvPF/iaz0WzyrTv+8l27hFGOWcjjoO2Rk4Heu3+K3jCNXi8CeHnMHh/RgLeQINpuZk4YueMgNnty25uflx5bRXrnwymtvDnwz8beKvPC3/kjT7byiBNCzjhhz0LOh/7ZH0ryOitHSNe1bQJpJtJ1G5spZU2O8EhQsvpxXd6Jq2p658LPiFPqeoXd9Oi6cFe5maRgvnsSAWJwK8zor6EvfEt/8NvgppNvNcebr+qwFbXOT9mhIHIx3VS
oH+03cLXz4SWJJJJPUmkrovBPhdfGPiOPRzqMNjJLG7RvKMh2A4QDI5P9DXo58Nn4Q+CPEMuuXtpNrOt2p0+0s7eQnEbZDucgdM56YGAM5bjxatXw1qsOh+JtM1S4thcxWlykzRH+IKc/nXuuv/tEaXBCj+HtOuLm7dQG+2Dy44/wUksfpge9c5b/ABM8J+M5Zv8AhNvDVql4U2xXdsMZHPBJIIx25PX8/KNcNmdbu/sBU2nmHytucbfxra8LaR4W1exuYta159IvkbMMjwmSJ17g45zV3x9rWi3VhoGhaJdS39vpFu8b38iFPPdyCdqnkKMYGf6ZPEVPY3k+nahbX1s2y4tpVmibGcMpBBx9RXqWteJfBHi2CbVLrUdT0O5nVft+m2lv5hvHB4IfIX/vrHTpmuL1jxTFNp76PoOnrpWkMwMihy890QODNJ35yQoAUZ6HANYumSWsWq2kl8jPaLMhmVTglM8/pXca38Nr/VNWlvfCH2fVtKum82I288QeHPJjdMjaRntxjFMg+HGo+G5otU8W3dpottA6yiF7hZLm4wc4ijQkk5wCSRjOe1cfrmqPreu32qSIEa6neXYDnbk5x+FangPwuvjHxfZ6K9z9nSYOzOBk4VS2B7nFfTFpeWvw58IXFkbVrSysEby7i4lX98xzjHJyScY/yK+TNRvJdQ1G4vJpHklmkLs8jFmYk9STVatvwrrw0DWPOnh+0WFxGbe9t/8AnrC2NwHuMAj3Ar0HXdR8MX13Y+K/7Vs7mw0yCOCw0PYVncocqsoxgLk5JGRgYGa8s1C+uNU1G5v7uTzLm5laWVsYyzHJ/nW34B8PweKfHGlaNdSFLe4lJlI6lVUsVHoSFxntmu98cw+O/EuqHwxo/hfUtP8AD1o/kWtnDbMkThTxI8mApyeeTgcdTknL8f3C+H/A2heBLi5W61ayma6vTHJvSAtu2RA+u1sn0/GvMq2fDHhnUPFurjS9MEbXbIXVZHCggdeTXrlxbWHwN8IXkQvYbrxrq8PloYRkWsXdgTyB3yfvMF4wpx4VRXqjRx+APg6sqtH/AG94tG3OAWhsccgHkfNkZ6ff9Uryuiium0Hx5rXhvw7qWiac0CW2o/652jy4yu07TnjI9jXM0UUpYkAEnA6CkpyO8UiyRsyOpBVlOCD6g1rXHizxJd2htLnxBqs1sRtMMl7IyEem0nFY9FdH4P8AGuqeCL27vNJS2Nxc25ty88ZYoCQcrgjnIHXI9q553aR2d2LMxySTkk02iiiiirMF/d2ttc20FzLHBdKFnjVsLIAcjcO+DyKrUVra74m1nxLJbSaxfyXbW0Qhh3gAIo9gOvqep71k0UA4ORT5JZJpDJK7O7dWY5J/GmUUUUUUUUUUUUUUVYsL+60y+hvbKd4LmFg0ciHkGrOr6/q+vTibVtSubxx93zpCwX6DoPwrOoooop8M0tvMs0Mjxyocq6MQQfYit+fx74vuYDBL4n1do2GCv2xxkYxg88iudJJOSck0Va07UrzSL+O+sLiW3uY87JYnKsuRg4I56Go7u8ub+6kury4luLiQ5eWZy7sfUk8moaKUszABmJAGBk9BSUV//9k=",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdwAAAAcCAAAAADuV1EpAAAP/klEQVR4Ae1aeXwUVZ6vV32l00m6EyCIwYjIEQ4lgItIUAQViRyBERCRD16j684BiLIKjjqzKwwwAxlHWEWyCKxCUOQakEvlCGeAABESQg4gIVcn6aSTPqvee+yr6u6qV0fHFdzPrp8P9UfX7/y+936/37u6m2FuP7cjcDsCv7wIgF9el2+6xwZ0066CI7Amm4aa1twShuRstbIuLHG3iVuIAEjqn3894M29lVIG/WftmxF3KwjUANIeWdGfpfj/c7Jn5XeW6J0AlntM0bXRNX38J37mURo72DUpiPVAjBFaGxu9Iz+iicn4dNsLb8eqkVljXPfBHawmtfxH0EZdyO/+IyY3rba+Z/3Jvl18GO+L5mVf2goRt+Em0tSEy3VB7dMz2kEzWOKiaI1ZpVXn56ti3anGT3ILEfJ31m1MLdSWadqRE38en2IDSmBTryUXr7qqq08te62fQQ0S5lmDtqfGsuC56BMFGEzWuE4Zn379amqcMQpqdDH4R9P9UbRsTLTizsEY5+p7sRM8iGgRv8Wsb8AAozIqkhnA+A8SQxGG3YFjJgaw+m4J6w5efCNRT5dUykGeb7JTUISMnenevW1Vdj7CzUqFHmeIcYxNVymmVlU+ajGwqlGY5zQEOE9bQ/H+TX/7Zmqcykdgrel/PXGlcnOGxajI8GTkm6FjLYgMPSd+VVLR1BxEZKXh8h9LUowSGDp2e6BX2uiM+5KsrKrSwoCdvLxO3YAuuzwkRZhbqehI2AecJ6oJYUb5sn0SFOYF7w8G99AaYOsxemCf3mNfP9mGoHM6rZLomRjbJIYihkIPmWSd3x1PySSy8642b8BV/bK2rB1OBLkA9A6SbMMEECaW+TTGI9QaiWfNFtaY8E87Nl44elE1wWc4v00idoA1KKZhakFd1Z753WJMBkPHj7akqeeuYVBxkBe2A771cF+6r+AaPKAXY9KCbdLOFh5CyEPO7+M4z0JFotiJxW4/F+T5gKfx4PgUVa2FRrIYu6QhESLWYOo6PttFkhd+ziqqJWRqIQsbTqPdwjT7npMMAHPNP1zyQl9XyiJpg5ushAIm+UCbKI1M+jEvMxRVAV8k3KOuw5QsQt7ZAH11TX5P45iIJPJmD/P8iTdWuBp7RCTK9108alJKJA7E/fnztesPlNfXnDqdRxYN+tngfDYUEQPLAjkr4N4tuyeF16qOO04OUkbNvNMrDJ33e/0Iln3QhcrS/b6WLgQf0AkPt9d1YzPHB10XC/avnDCr3OfbdQeFyn7MIbK5wEDzdae/PsNuVS1QAgbL48wwlvACwyYWQjEHQpH5Sab4eEodJnuQjgaphiSLtTxCwaC3qWzl7AYeH5PkDNO9WRgdAXU7gwg+RmlkEuMKmZGpzrBNCGKmv0iWRShDGeSO/KHEw/vWq/sT851nW3L8KyePUJGMuAlvNg+jKDtHzIKLV0/8Y++qKXaj0aIYv9Hp+ZcwCDAaLHJyGcevIqFne+//1450S8wEcTVrzp7w+Pgtfr714MPUdpflXSb2Jj4xhkIT3dnhFZz3zLB+ibHmmLRt1xtODKJQHya5bS07v+OZh7pN2v6SA7CKfoYMd2NIeTAxLx0Xd0zo/ibVZkp4B2LYndaH6KdIli5qxcxyYu6uL992JPfZ1Kw63ELV/DM+hIMtFf9+X8LIOnhFPQwRKxnjiTqgzFm0WBBP5g9ptZkcl5fxbD4Hvb9Vl75p15ms1Ac27xys9QpJfodxB10dO7TswowecTqZZ1vg9ogLa3dQ9QRsqaEiMnR5YsnmvhEj4c0Km1jjsgGxZGO0ZvO4dZaDGv+Cy2KmY+cfaiy9k3YjdIfs+ksL+loBMHV9ZuvOvzyYJOtTXQg1f7zoubs6xtpSUqyygqJsGC+gWMYw4gjvd+3ICh/zUyHmk2l9iM4lvb1PK54AMbfn1Q+nZ6V3
sBvYLzGeLtuY/su5eawY/2mQ19+uczBW71SCvwkhoUhAAxoqw4UpYy260v/ur/0Y1mba44yKU4XxrXNbr5Q3bYzReIUFQzDqpqvrsO1iPyr8lE0t3BziusaZ47ummuSCAjFx1k73vvHF/tMHdq0eTmeJHUWOQ3vDy8dIiFFjT6oomFWFIuKTl4I8t5GaDaI0KXvXZ4v7xPd64j/eWzSoYyfZz7iJbH4n1pfmHytcsHxaB/3elmBE9Z2Q5vhHusmHsiyy/MoDkCzJEou0S7ydHMEKx7zzp4c7i5t7igfvlTwYED/YIvbN3IZr9QsNKReRiO8RXCCQmZDTDuHXCK4evJEnm0fB3+ePSkyKjTiRN3uaE7aBht6UjCbBJoQm0wKJ/mPTbP2lfA2qF8Nv+rdv+lgfmZPeUXJhgMl2z/zL3kBrVcGm3CGKJJGpcC1sOZbsc/7J9DhAebEA4jjmvl5deqW/DChSILHvsYb6skbnpQ+Gmxk5t0xSC8buKnIZgIEmd1FPlVuItWNMnXqENHLcYdnSmEMEUObDlJA73KYR51gZ7gBu7FbScEPQNfDMANnmRtsZIIoLbcw+3XMTC5gS2V6iwBA4glwKHl8EtmBJGCbArBve/suHkfnudYxa6/dBo0mGxl8PJGbAvm64dgREbhkxmvF/q0YUeOtwbyCWEzurVHeaHrif4IMHf93p71dQa6DBJetv8Pz1zw7HF9+4y/Fk+lOnZQUDOt+4cRYLgwc7xjJM4O5GGvqGTUzZK3euX9g6cHX3C5QjIW80ewsGWDoywbauVRxD+ZnrYrzfekbyhc5B9zh6jixVuoW4RsZ3XU8ekoHZpPTztHoD+QqD04gdQYw23y1vVWyz3lcS68ksm2ez6KyVpDF6NYvgPwGv2R3H+co2eG9EJL3NRX5vkZt0ki/47i89QiuDpGTu2FB5+ZvzQfScLIpQIDnj/UKEXEkRAf2OXXi09KNEWhKmc7wLSSZif7cvJwUwlt/s1znCEEv2vm8r6LnJxJWUf/LPfcjV6gfS02PCDKKfc2Vkmpt2vEkiYrw0htaItD2XR7DN4z7WR6lKzx6TQoqaOO3BSKULWVqjXCxDWjbJRZa1ISGG/jSSxSVAC0T6C4Sqhlht0rjMLlyoMZpJxnfOEm9PEOtVoSYHOa2QYQ4jv4usu3VNbdqlMr7Y0+xsxNj5Wm9HrEGx44rQBI89ir2KZgQm66Dzwg9tGPsGh1Ss4ptEU8aHx0tGa5wYJs+XCZge+dXvClG969Q+MbiinaLjYKvnHoV3h2Gz84rOXCajgCsVCoF5+1o/hn3fKRxDTDWj1OqUEn+gdcWSitbmbEUbZGARvhxDnanCMEHlHVeJDGI+J5ko04aUAeSa61VXIFOE4JtmRsotk+HWHpz6EsSiRHL9j/RMbtJEYi1zMvVSAPEQlWXWu0mFqx7j9uqqQg+q703unCqVxD6PkZyEkPQogWwWKgZVrE8yssCYYJfMBQIk/KlkmxYQVDS8b7Ov2H+v0Pmk46X9Ik4JjggVes+/qpqBxl7rAsIlxDlWi9qlfqG9Q17N/cT1kYaBSiDmZS8K7oxhLU9XBo9ohy8axyDsUXmJ7BCMtY1JhqAfuXjz6RJPEa1kN9e0FcB+B2Vj2s59rykAMsLL6lCHfVZgPJJyl0jDuOlpd5kY1h3QFgR4YeNZD0azJWMdYhziVd0YTS7j/oIrwnELthxeOWdKWnqyVMJiI2BAxWGVE0FmT9Z8P21MlpmQpiklpZ2lxjI/UnZtRFGmpBMJQ49Csthh72RN0Ih66/FH51XWJJP1vvSSEodZgHDgt0KKwKrASZUu3AKoxnh5mFa8EH2cVWhE5m+kP1v1sg/qySyT6jbsaIK4lu77c01tj6sxp+jt1WEjcvzTa0uC8OrN67gXAljnaCc5EeI/oUuJC14PwkDbFSfHke+CPS2utqovPxbmjfiYetttBgY86VeMJawc0NByfbgRsJYHyppOyjUa+9a79Cmd
YX5fcEfYQ3yBuOEfuklZ86ceig13hU6UvfBA9pnPiGJG4EHajWEWBFCrIyT6OHBAqYtwrxBcuSsRKcOswUhbnrLaQdatQEeZlykT+cLF96IyZIyFQ6cpuKf9gQPqRllSLSovCdTQTt5FI6g9wpHTxBcEkmpVgiMEa48jn6YWuJ+WEnq4j+f9fveFeQ+lZ712tq7RU5ydGcGwzcuZlZrg2AH30PEPAxgOtrj+mjns92tbGrtTmP3Plr5KWxtX50TgBCu2/5Ltp2oavK11u4bFh38OpM2ZcdfrDw0C7AMHVymD09PHNUSqpiqwgWpRIIFZAJlCimaaSiNqEX5XRxwRmch8x3MjnOIdR1Y0WKTcpxgrB4sjXWFiiyBcS49Q9HdGAyTaRWQzVjSiYgDUu9B0JxvEFJVlhH25YKkVWM8gTYGy77i5gHOZeAphO524XnvuU+mr/LjFtXUle6p8/gERGPptnVvZ6vH6q+crshCz1FVIHZztece6KJzOXa7K27b19Pmi+Sk6Zw2Spq/qcueOm1W4TjkZ2EO+erJYi88QzqNeKRPvsLOmBSRHF+nGIvQZ7UkjoiJvsJz4VYsBoKQh0kCWUOxLU8qJMPi2UE3ANq6RR/xuaRuL2GXhdn5ka8Y4JWKo9wY8VERUtDEdhqhSz1qQ7eUb9+X5EV6nMTCkTH1LqkPH+N/M+WCqhN19fV2rz1c+WOMkCkCfNTWXZ5pVyhGrm2uGhOakZdQPrZVPK/TPNLZce33pG7nFF17U/nEgZNkzt6axrUI1/pga/5JQt9i3XNzzCkzCJI+c+cc6kod6tULguyJ8VE8elnUicwLNVCwfsvX3BDTYU+YFyrSUOHhPfFUh/NqAr07RzFtyxkb6xSK41+GgsnAFIf0EITU7woqhEAdVIZFd5pLEkueaLNGnzFm/yqCgYwbPeEo8D+tb60lNk7Y0tx04tHX/cY8HwmvdlDYTIfQ59+ZXlefcHXWE8VOP5vVVujHmWpSfzjIg8fk6nhcu2MrHsrJeCHSTJs6C2RUUSFCaK7i5xLGqswYyZNOrCsKrqkSxCeKP9EI8UXCeTpOrMS5VNKFgsrgvozQWNqtF2pXyZR7r/wYh+BhyhN3js6gBlZpvv13JrD3C0eeTXaerWznYtLWbtAiEHVI85CfXlvriRYrLtBpNe5eznOJhsPUa+VcLCm5WmxP+eXLywefUrYUM8/n1+gpRDZaRuTlZJ0UhZzaxmzZoxjmfB8kvyxff1IxPcErmMZoW8tb9FE4/7T25bs0cZQ9CuCa6D0h6MUOzN0Q3v0UNOUSzultq8uJSb2PtB0PCAfsflxIYuZz8A4NMlbYld+r1LamWa10poUmEaJo8WDX1FACGz7nAqYHaBCqMdBhlG7TBsxA3WGnBT6TjxkZOF5KjaVr9IYfEkd2eov//kMC66p1J5H710x825bG0J4ZHm4LmHg+2l8F2mjNMKtk2SfefG+04tat61OOefDMjbA/UZriVcmkP+WfV/S8VXW/dCf2z9vw22C8jAv8NCqu8B3OeFNIAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=476x28>"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode a batch of latent samples drawn from the prior and visualise them.\n",
    "# Gradients are not needed for visualisation, so avoid building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    z_latent = prior().sample((n_features + 1,))\n",
    "    # `decoder` returns a distribution over pixels (it exposes .sample() and .mean),\n",
    "    # with one entry per latent sample along the leading dimension.\n",
    "    px = decoder(z_latent)\n",
    "\n",
    "    # Draw one pixel sample per latent and average them into a single 28x28 image.\n",
    "    # Bound to `z` because the following cells display this image.\n",
    "    z = px.sample().mean(dim=0).reshape(-1, 28)\n",
    "\n",
    "    # Tile the per-latent mean images side by side: (N, 28, 28) -> (28, N * 28).\n",
    "    x_mean = px.mean.reshape(-1, 28, 28)\n",
    "    grid = x_mean.movedim(0, 1).reshape(28, -1)\n",
    "\n",
    "to_pil_image(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588,\n",
       "         0.0588, 0.0588, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588,\n",
       "         0.0588, 0.1176, 0.0588, 0.1176, 0.0588, 0.0000, 0.1176, 0.0588, 0.1765,\n",
       "         0.1176, 0.0588, 0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.0588,\n",
       "         0.0588, 0.1765, 0.1176, 0.1765, 0.1765, 0.2941, 0.1765, 0.3529, 0.3529,\n",
       "         0.2353, 0.1765, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588,\n",
       "         0.0000, 0.1765, 0.1176, 0.2941, 0.4118, 0.4706, 0.4118, 0.5294, 0.5294,\n",
       "         0.2941, 0.2941, 0.2353, 0.1176, 0.0588, 0.0000, 0.0000, 0.0588, 0.0588,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0588, 0.1765, 0.2941, 0.4118, 0.5294, 0.5882, 0.5882, 0.5882, 0.5882,\n",
       "         0.4706, 0.4706, 0.2941, 0.4118, 0.1765, 0.1176, 0.0588, 0.0588, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0000, 0.0000, 0.0000, 0.1765,\n",
       "         0.2941, 0.2941, 0.4118, 0.5294, 0.5294, 0.5882, 0.7059, 0.5882, 0.5882,\n",
       "         0.4706, 0.3529, 0.4118, 0.4118, 0.2353, 0.1176, 0.0000, 0.0588, 0.0588,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.0588, 0.0588, 0.0588, 0.1765,\n",
       "         0.3529, 0.4118, 0.5882, 0.5294, 0.4118, 0.4706, 0.5882, 0.6471, 0.4706,\n",
       "         0.4118, 0.3529, 0.4706, 0.2353, 0.1765, 0.1765, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.1176, 0.2941,\n",
       "         0.4118, 0.5294, 0.5294, 0.4118, 0.4706, 0.2941, 0.2941, 0.4706, 0.1765,\n",
       "         0.2941, 0.3529, 0.2353, 0.1765, 0.1176, 0.1176, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.0588, 0.0000, 0.1765, 0.2941,\n",
       "         0.3529, 0.4118, 0.1765, 0.2941, 0.2941, 0.2941, 0.5882, 0.4118, 0.3529,\n",
       "         0.4706, 0.3529, 0.2353, 0.0588, 0.1176, 0.0588, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1176, 0.1765, 0.2353,\n",
       "         0.4706, 0.3529, 0.2353, 0.3529, 0.3529, 0.4118, 0.5882, 0.5294, 0.5882,\n",
       "         0.5294, 0.3529, 0.2941, 0.1176, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1176, 0.2353, 0.2941,\n",
       "         0.4118, 0.2941, 0.1765, 0.2353, 0.3529, 0.5882, 0.5294, 0.7059, 0.5294,\n",
       "         0.5294, 0.4706, 0.2353, 0.0000, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1765, 0.2941,\n",
       "         0.3529, 0.1176, 0.1176, 0.1765, 0.3529, 0.6471, 0.6471, 0.5294, 0.4118,\n",
       "         0.4118, 0.4118, 0.2353, 0.0588, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.1176, 0.2941,\n",
       "         0.2941, 0.1765, 0.2941, 0.4118, 0.5882, 0.6471, 0.5294, 0.4118, 0.4706,\n",
       "         0.2941, 0.1765, 0.1176, 0.1176, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1765, 0.1765,\n",
       "         0.2353, 0.2941, 0.2941, 0.4118, 0.6471, 0.5882, 0.4118, 0.2353, 0.3529,\n",
       "         0.2941, 0.2353, 0.1765, 0.0588, 0.0588, 0.0588, 0.0588, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.0588, 0.0588, 0.1765, 0.1765,\n",
       "         0.1765, 0.3529, 0.2941, 0.2941, 0.5294, 0.4706, 0.4706, 0.2353, 0.3529,\n",
       "         0.2353, 0.1765, 0.2941, 0.1176, 0.0588, 0.1176, 0.0588, 0.0588, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.1765, 0.1765,\n",
       "         0.2941, 0.3529, 0.3529, 0.4118, 0.5294, 0.5294, 0.3529, 0.3529, 0.3529,\n",
       "         0.2941, 0.1765, 0.1765, 0.1176, 0.1176, 0.0588, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1176, 0.1176, 0.1765, 0.2353,\n",
       "         0.1176, 0.2353, 0.2353, 0.4118, 0.4118, 0.4118, 0.2353, 0.3529, 0.2941,\n",
       "         0.1765, 0.1765, 0.1765, 0.1176, 0.1176, 0.0588, 0.0588, 0.0588, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0588, 0.1765, 0.2941, 0.2353,\n",
       "         0.2941, 0.4118, 0.4706, 0.5882, 0.5882, 0.4118, 0.2353, 0.3529, 0.3529,\n",
       "         0.2353, 0.1765, 0.1176, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1176, 0.1176, 0.2353, 0.2353,\n",
       "         0.1765, 0.3529, 0.3529, 0.5882, 0.4706, 0.4118, 0.2941, 0.4706, 0.2941,\n",
       "         0.1765, 0.1176, 0.0588, 0.0588, 0.0588, 0.0000, 0.0000, 0.0588, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1765, 0.2353, 0.1765,\n",
       "         0.4118, 0.4118, 0.5882, 0.4706, 0.5294, 0.5294, 0.4118, 0.3529, 0.2353,\n",
       "         0.1176, 0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.2353, 0.2941,\n",
       "         0.4118, 0.5294, 0.6471, 0.5294, 0.4706, 0.4118, 0.2941, 0.2941, 0.1176,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1176, 0.1765, 0.2353,\n",
       "         0.2941, 0.3529, 0.3529, 0.2353, 0.2353, 0.1765, 0.0588, 0.0588, 0.0588,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.1176,\n",
       "         0.0588, 0.1176, 0.1176, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0588, 0.0588, 0.1176, 0.0588, 0.0588, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000]], device='cuda:0')"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the raw values of `z`, the averaged pixel-sample image built in the cell above.\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAFALAE4962tO8NXmrpmxQvjkk9Mc/4Uat4V1LSM+dEXAxuKDgf5yKxSpU4PBqezCyTpG+NrMMk9q9b0O4nkvIrCyHkrEqvJKiHIOMn5jx0wBmsnxpdvB+7S9E7SMTJ8oHAyRx/d6Y9a8zbIYgnOOKuaWqPexq7Y+YYr1TR7W4s9GluQdheMqgYYDDryd307f8A1uG8Q39uLp7dWD+XgDAzjArmGOWJ9aFdkYMrFWHQircmqXksPlPO5X3Y1TJJOScmiv/Z",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABMUlEQVR4AbWR0XHEMAhEcSb/dEAH6sAdXL3pwB2oAzqgg7xFlm/mJr+Rzz6ZFfAWm/3LOnZVZ+O539b/1/2K1r/wEa6DWt/rr/PMPfK0rKoVvcVyIxDhEVZehkzk6WnEEUlJm3GhlnVPhcaJRrLUUBq3yjoXTx+VbCfha7VZZYkN8tKnjZ9wkCi+aMkrOPKi5muOSE41r8qK02ok/tyCuIxKfobguEsbNcqCTa/tE3OTgPuUIEDWEkVeGB2WQCWsvVZZ2iRMNEoynRaP6D1pjid58BWD0mVHHwESTCZoNAanh3vTKpVW+OO+UVVtAVVE8R07sTtq7NunCJIUYAK//Ul2Jr6tTusZ6VM3K4/bSjYRTsjP1gR6NC2oglo0j7ZFGZfF5ljn35nscPOhPYf+3PwCXlOQlu/6FvIAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28>"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the averaged pixel-sample image `z` as a grayscale (mode L) PIL image.\n",
    "to_pil_image(z)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "app",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
